
Conversation

@hubertlu-tw hubertlu-tw (Contributor) commented Jan 27, 2026

Description

This PR fixes ROCm FP8 handling across gfx942/gfx950 by selecting the correct
FP8 variants at runtime and making MFMA/codegen recognize the FP8 dtypes used
by ROCm. It also consolidates FP8 selection into shared helpers so examples and
tests stay consistent across devices.

Key changes

  • Add select_fp8_e4m3_dtype() / select_fp8_e5m2_dtype() helpers and Torch variants (a sketch of the E4M3 selection logic follows this list).
  • Route ROCm E5M2 through BF8 MFMA intrinsics and add missing MFMA dtype mappings.
  • Fix FP8 E4M3/E5M2 conversions and vector wrappers in HIP templates for gfx950.
  • Update FP8 examples to use shared selection logic and ROCm-friendly paths.
  • Make ROCm FP8 tilelibrary tests select per-GPU dtype instead of hardcoding FNUZ.
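
For reference, the E4M3 selector roughly follows the logic discussed in the review comments below. This is a minimal sketch, not the exact implementation: the gfx950 return value is inferred from the PR description, and the current-device lookup reflects the reviewer's suggestion rather than the code as originally submitted.

import torch

def select_fp8_e4m3_dtype() -> str:
    # CUDA builds use the OCP E4M3 encoding.
    if torch.version.hip is None:
        return "float8_e4m3fn"
    # ROCm build without a visible GPU: fall back to the FNUZ encoding.
    if not torch.cuda.is_available():
        return "float8_e4m3fnuz"
    # Query the active device (per the review suggestion) rather than device 0.
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    if getattr(props, "gcnArchName", "").startswith("gfx950"):
        return "float8_e4m3fn"   # gfx950: OCP (fn) variant, inferred from the PR description
    return "float8_e4m3fnuz"     # gfx942 and older: FNUZ variant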

Tests

  • pytest -q testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
  • python /opt/tilelang/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
  • python /opt/tilelang/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
  • python /opt/tilelang/examples/gemm_fp8/regression_example_gemm_fp8.py

CC: @Gongen-Ali

Summary by CodeRabbit

  • New Features

    • Automatic FP8 dtype selection adapts to platform and runtime, including additional FP8/BF8 variants.
    • Expanded FP8 type wrappers and storage forms to improve AMD/HIP interoperability.
  • Refactor

    • Examples and tools updated to use dynamic FP8 selection for cross‑platform consistency.
  • Tests

    • GEMM tests updated to use runtime-selected FP8 dtypes for broader coverage and accuracy.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai bot (Contributor) commented Jan 27, 2026

📝 Walkthrough

Introduces runtime FP8 dtype selection utilities and replaces hardcoded FP8 dtypes across examples, tests, and kernels; refactors HIP FP8 type wrappers; and extends MFMA intrinsic generator to recognize additional FP8/BF8 variants. Changes are limited to dtype selection, wrappers, and generator mapping.

Changes

Cohort / File(s) Summary
FP8 Runtime Selectors
tilelang/utils/target.py, tilelang/utils/__init__.py
Added select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype, select_fp8_e5m2_dtype, select_torch_fp8_e5m2_dtype and re-exported them for platform-aware FP8 dtype selection and torch.dtype mapping.
Examples — GEMM FP8 (runtime dtypes)
examples/gemm_fp8/example_tilelang_gemm_amd.py, examples/gemm_fp8/example_tilelang_gemm_fp8.py, examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py, examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py, examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
Replaced hardcoded FP8 dtypes with runtime selection calls; changed the tl_matmul default in_dtype to None with a runtime fallback; in the intrinsic example, the emitter (MFMA vs. MMA) is now selected at runtime and micro-tile/thread sizing follows the chosen emitter.
HIP FP8 Type System
src/tl_templates/hip/hip_fp8.h
Reworked FP8 typedefs to HIP-specific aliases and added explicit wrapper structs and storage aliases (E4/E5 wrappers, aligned vector wrappers) with constructors/conversions for interoperability.
MFMA Intrinsic Generator
tilelang/intrinsics/mfma_macro_generator.py
Added recognition/abbreviation for float8_e4m3fn (fp8) and BF8 paths; extended k-dim initialization and MFMA suffix logic to accommodate the new FP8/BF8 variants (see the sketch after this table).
HIP Codegen dtype map
src/target/codegen_hip.cc
Added MFMA dtype mappings for float8_e5m2x4 and float8_e5m2x8.
Tests updated for runtime dtypes
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
Replaced hardcoded FP8 dtype literals in test parametrizations with calls to select_fp8_e4m3_dtype() / select_fp8_e5m2_dtype() for platform-aware testing.
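
To make the MFMA generator change concrete, here is an illustrative sketch of the kind of dtype-to-abbreviation mapping and suffix selection involved. The names and table below are hypothetical and only meant to show the shape of the change, not the actual code in mfma_macro_generator.py.

# Hypothetical mapping: ROCm FP8 E4M3 variants map to "fp8" MFMA intrinsics,
# while E5M2 variants are routed through the "bf8" intrinsic paths.
MFMA_FP8_ABBRV = {
    "float8_e4m3fnuz": "fp8",   # gfx942 FNUZ E4M3
    "float8_e4m3fn": "fp8",     # gfx950 OCP E4M3 (newly recognized)
    "float8_e5m2fnuz": "bf8",   # E5M2 handled via BF8 MFMA paths
    "float8_e5m2": "bf8",
}

def mfma_fp8_suffix(in_dtype: str, k_dim: int) -> str:
    # Builds a suffix such as "f32_16x16x32_fp8_fp8"; the exact naming scheme
    # used by the generator may differ.
    abbrv = MFMA_FP8_ABBRV[in_dtype]
    return f"f32_16x16x{k_dim}_{abbrv}_{abbrv}"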

Sequence Diagram

sequenceDiagram
    participant App as Application / Example
    participant SelStr as select_fp8_e4m3_dtype()
    participant Platform as Platform Detector
    participant SelTorch as select_torch_fp8_e4m3_dtype()
    participant Torch as PyTorch

    App->>SelTorch: request torch.dtype for FP8
    SelTorch->>SelStr: request FP8 dtype string
    SelStr->>Platform: detect target (CUDA / ROCm / gfx)
    Platform-->>SelStr: platform info
    SelStr-->>SelTorch: return dtype string (e.g., "float8_e4m3fn" / "float8_e4m3fnuz")
    SelTorch->>Torch: map string -> torch.dtype
    Torch-->>SelTorch: return torch.dtype
    SelTorch-->>App: provide runtime torch.dtype for kernels/tests
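
The torch-level wrapper in the diagram reduces to a thin mapping from the selected dtype string onto the corresponding torch.dtype. A minimal sketch follows; the real helper lives in tilelang/utils/target.py and may differ in detail.

import torch
from tilelang.utils import select_fp8_e4m3_dtype  # re-exported by this PR

def select_torch_fp8_e4m3_dtype() -> torch.dtype:
    name = select_fp8_e4m3_dtype()  # e.g. "float8_e4m3fn" or "float8_e4m3fnuz"
    return getattr(torch, name)     # both names exist as torch dtypes in recent PyTorch

# Example: allocate an FP8 operand with the platform-appropriate dtype
# a = torch.randn(128, 128, device="cuda").to(select_torch_fp8_e4m3_dtype())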

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • Gongen-Ali
  • LeiWang1999

Poem

🐇 I hop through types both thin and wide,
Choosing FP8 by platform-side.
CUDA, ROCm — I sniff the breeze,
Pick the dtype with practiced ease.
Wrapped and mapped, the kernels stride.

🚥 Pre-merge checks: 2 passed, 1 failed
❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): docstring coverage is 9.38%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): the PR title clearly and specifically describes the main change, fixing ROCm FP8 dtype selection and MFMA support on gfx942/gfx950 AMD GPUs, which aligns with the core objectives of the changeset.



Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@tilelang/utils/__init__.py`:
- Line 3: Remove the unused "# noqa: F401" from the import line in
tilelang/utils/__init__.py so Ruff no longer flags the directive as unused;
locate the line that imports determine_target, select_fp8_e4m3_dtype, and
select_torch_fp8_e4m3_dtype and delete the trailing "# noqa: F401"
(alternatively, if you intended to silence F401, enable F401 in the Ruff config
instead).

In `@tilelang/utils/target.py`:
- Around line 73-79: The dtype-selection logic currently queries device 0 via
torch.cuda.get_device_properties(0); change it to use the active CUDA/HIP device
by calling torch.cuda.current_device() (or equivalent) and pass that index into
torch.cuda.get_device_properties so the gcn_arch check (gcnArchName) reflects
the currently selected GPU; update the block in tilelang/utils/target.py where
torch.version.hip, torch.cuda.is_available(), props =
torch.cuda.get_device_properties(0), and gcn_arch.startswith("gfx950") are used
to instead call torch.cuda.get_device_properties(current_device) (using
torch.cuda.current_device()) before inspecting gcnArchName.
🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)

67-79: Consider adding a float constructor for API symmetry.

fp8_e5_t lacks a constructor from float while fp8_e4_t provides one (lines 43-51). If this asymmetry is intentional for the current use cases, this is fine. Otherwise, consider adding it for API consistency:

♻️ Optional: Add float constructor to fp8_e5_t
 struct fp8_e5_t {
   unsigned char data;
   __device__ fp8_e5_t() {}
   __device__ fp8_e5_t(hip_fp8_e5_t val) {
     data = *reinterpret_cast<unsigned char *>(&val);
   }
+  __device__ fp8_e5_t(float val) {
+    data = __hip_cvt_float_to_fp8(val, __HIP_SATFINITE, __HIP_E5M2_FNUZ);
+  }
   __device__ operator hip_fp8_e5_t() const {
     return *reinterpret_cast<const hip_fp8_e5_t *>(&data);
   }
   __device__ operator float() const {
     return static_cast<float>(static_cast<hip_fp8_e5_t>(*this));
   }
 };

"""The profiler and convert to torch utils"""

from .target import determine_target # noqa: F401
from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype # noqa: F401

⚠️ Potential issue | 🟡 Minor

Remove the unused # noqa: F401 to satisfy Ruff.

Ruff flags the directive as unused on this line, which can fail linting. Either drop it or enable F401 in the config.

🧹 Proposed fix
-from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype  # noqa: F401
+from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype  # noqa: F401
+from .target import determine_target, select_fp8_e4m3_dtype, select_torch_fp8_e4m3_dtype
🧰 Tools
🪛 Ruff (0.14.14)

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


Comment on lines +73 to +79 of tilelang/utils/target.py:

    if torch.version.hip is None:
        return "float8_e4m3fn"
    if not torch.cuda.is_available():
        return "float8_e4m3fnuz"
    props = torch.cuda.get_device_properties(0)
    gcn_arch = getattr(props, "gcnArchName", "")
    if gcn_arch.startswith("gfx950"):

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /tmp && find . -type f -name "target.py" | head -20



🏁 Script executed:

find . -type f -name "target.py" -path "*/tilelang/utils/*"



🏁 Script executed:

find . -type f -name "*.py" | grep -i target | head -20



🏁 Script executed:

git ls-files | grep target.py



🏁 Script executed:

cat -n tilelang/utils/target.py



🏁 Script executed:

rg "select_fp8_e4m3_dtype" --type py -B 2 -A 2



🏁 Script executed:

rg "torch.cuda.get_device_properties" --type py -B 3 -A 3



🏁 Script executed:

web_search
PyTorch torch.cuda.current_device() torch.cuda.get_device_properties() multi-GPU best practices



🌐 Web query:

PyTorch torch.cuda.current_device() best practices multi-GPU systems

💡 Result:

What torch.cuda.current_device() means (and why it can surprise you)

  • PyTorch tracks a per-process “currently selected” CUDA device; torch.cuda.current_device() returns that device index. By default it’s typically cuda:0 unless you change it. CUDA tensor allocation defaults to this device. [1][2]
  • If you restrict visibility with CUDA_VISIBLE_DEVICES, then cuda:0 is “GPU 0 among the visible set”, which may map to different physical GPUs in different processes—so seeing current_device()==0 can be totally correct in multi-process setups. [1][5]

Best practices on multi-GPU systems

1) Prefer explicit devices over relying on the “current device”

Use torch.device("cuda:N") (or a passed-in device) and move tensors/modules explicitly (.to(device)). This avoids accidental allocations on the wrong GPU when the current device changes. [1]

2) If you do rely on the current device, set it intentionally (per process)

  • Use torch.cuda.set_device(local_rank) early before creating CUDA tensors (or wrap code in with torch.cuda.device(local_rank):). [1][4]
  • This matters for distributed/NCCL: some collectives use torch.cuda.current_device() to decide which GPU to use, and PyTorch explicitly puts the responsibility on you to ensure each rank has the right GPU selected. [4]

3) With torchrun/DDP: map one process ↔ one GPU

  • torchrun provides LOCAL_RANK; use it to select the GPU and to configure DDP device_ids / output_device. [3]
  • Minimal pattern:
import os, torch
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = model.to(local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], output_device=local_rank
)

[3][4]

4) Use CUDA_VISIBLE_DEVICES to control which GPUs PyTorch can see

This is often the cleanest way to partition GPUs across jobs/process groups, and it changes the meaning of device indices inside the process (re-indexing from 0). [1][5]

5) When creating new tensors inside forward, don’t depend on the current device

If you need a new tensor “like an existing tensor”, use x.new_* / new_tensor so it inherits device/dtype, rather than torch.* factory calls that can follow the current-device context. [1]

Sources: PyTorch CUDA semantics / best practices [1], torch.cuda API docs [2], torchrun docs [3], torch.distributed docs (NCCL note about current_device) [4], PyTorch forum guidance on current_device, set_device, and CUDA_VISIBLE_DEVICES behavior [5].


Use the current device when querying GPU architecture for dtype selection.

In multi-GPU ROCm/HIP systems, device 0 may not be the active device. When selecting the FP8 dtype based on GPU architecture (gfx950 vs. other), the function must query the current device instead of hardcoding device 0, otherwise it will return the incorrect dtype.

🔧 Proposed fix
-    props = torch.cuda.get_device_properties(0)
+    device = torch.cuda.current_device()
+    props = torch.cuda.get_device_properties(device)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (original):

    if torch.version.hip is None:
        return "float8_e4m3fn"
    if not torch.cuda.is_available():
        return "float8_e4m3fnuz"
    props = torch.cuda.get_device_properties(0)
    gcn_arch = getattr(props, "gcnArchName", "")
    if gcn_arch.startswith("gfx950"):

Suggested change (replacement):

    if torch.version.hip is None:
        return "float8_e4m3fn"
    if not torch.cuda.is_available():
        return "float8_e4m3fnuz"
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    gcn_arch = getattr(props, "gcnArchName", "")
    if gcn_arch.startswith("gfx950"):


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py`:
- Around line 237-240: The selected fp8 helpers (select_fp8_e4m3_dtype and
select_fp8_e5m2_dtype) return dtype names as strings but
tl_matmul/assert_tl_matmul_correctness expect TileLang dtype objects (e.g.,
T.float8_e4m3fn); update the calls so you convert the returned string to the
TileLang dtype object before passing to assert_tl_matmul_correctness (for
example by resolving the string via the TileLang type namespace, e.g., mapping
or using getattr(T, dtype_name) to get T.float8_e4m3fn), and use those resolved
dtype objects when invoking assert_tl_matmul_correctness and tl_matmul.
🧹 Nitpick comments (1)
src/tl_templates/hip/hip_fp8.h (1)

160-179: Add storage ctor/assignment for fp8_e5_4_t for API parity.

fp8_e4_4_t supports construction and assignment from its storage type (uint32_t); fp8_e5_4_t doesn't. If generated code attempts to use packed storage for E5, this API gap can cause compilation errors. Align fp8_e5_4_t with the fp8_e4_4_t interface by adding a storage constructor and assignment operator.

♻️ Suggested parity additions
 struct __align__(4) fp8_e5_4_t {
   union {
     fp8_e5_4_storage_t data;
     struct {
       fp8_e5_t x;
       fp8_e5_t y;
       fp8_e5_t z;
       fp8_e5_t w;
     };
   };
   __device__ fp8_e5_4_t() {}
+  __device__ fp8_e5_4_t(const fp8_e5_4_storage_t &val) : data(val) {}
   __device__ fp8_e5_4_t(const hip_fp8x4_e5_t &val) {
     data = *reinterpret_cast<const fp8_e5_4_storage_t *>(&val);
   }
   __device__ operator hip_fp8x4_e5_t() const {
     return *reinterpret_cast<const hip_fp8x4_e5_t *>(&data);
   }
+  __device__ fp8_e5_4_t &operator=(const fp8_e5_4_storage_t &val) {
+    data = val;
+    return *this;
+  }
 };

Comment on lines +237 to +240
e4m3_dtype = select_fp8_e4m3_dtype()
assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32)
e5m2_dtype = select_fp8_e5m2_dtype()
assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /tmp && cat > verify_fp8.py << 'EOF'
import sys
try:
    import tilelang.language as T
    from tilelang.utils import select_fp8_e4m3_dtype, select_fp8_e5m2_dtype
    for fn in (select_fp8_e4m3_dtype, select_fp8_e5m2_dtype):
        d = fn()
        print(f"{fn.__name__}: value={d}, type={type(d).__name__}, repr={repr(d)}")
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
    import traceback
    traceback.print_exc()
EOF
python verify_fp8.py



🏁 Script executed:

# Find the implementations of these selector functions
rg "def select_fp8_e4m3_dtype|def select_fp8_e5m2_dtype" --type py -A 5



🏁 Script executed:

# Check what tl_matmul does with in_dtype
rg "def tl_matmul|def assert_tl_matmul" --type py -A 15



🏁 Script executed:

# Check the example file itself
cat -n examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py | sed -n '230,245p'



Convert string dtypes from select_fp8_*() to TileLang dtype objects.

select_fp8_e4m3_dtype() and select_fp8_e5m2_dtype() return strings (e.g., "float8_e4m3fn"), but tl_matmul asserts in_dtype against TileLang dtype objects like T.float8_e4m3fn. This will cause an AssertionError at runtime because "float8_e4m3fn" != T.float8_e4m3fn.

Safe normalization
+def _tl_dtype(d):
+    return getattr(T, d) if isinstance(d, str) else d
+
 def main():
-    e4m3_dtype = select_fp8_e4m3_dtype()
+    e4m3_dtype = _tl_dtype(select_fp8_e4m3_dtype())
     assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32)
-    e5m2_dtype = select_fp8_e5m2_dtype()
+    e5m2_dtype = _tl_dtype(select_fp8_e5m2_dtype())
     assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32)

@hubertlu-tw hubertlu-tw changed the title from "[AMD] Fix gfx950 FP8 E4M3 selection in AMD FP8 examples" to "[AMD] Fix ROCm FP8 dtype selection and MFMA support on gfx942/gfx950" on Jan 28, 2026
@LeiWang1999 LeiWang1999 (Member) left a comment

Thanks for your contribution. I left a simple comment suggesting that it would be better to rename select_fp8_type to determine_fp8_type.

In tilelang/utils/target.py:

    return arch == "arm64"


def select_fp8_e4m3_dtype() -> str:

It would be better to rename this to determine_fp8_type.
